Skip to content

[Feature][Performance] NextObservationDelta env transform#3777

Open
vmoens wants to merge 4 commits into
pytorch:mainfrom
vmoens:worktree-next-obs-delta
Open

[Feature][Performance] NextObservationDelta env transform#3777
vmoens wants to merge 4 commits into
pytorch:mainfrom
vmoens:worktree-next-obs-delta

Conversation

@vmoens
Copy link
Copy Markdown
Collaborator

@vmoens vmoens commented May 18, 2026

Summary

  • Adds NextObservationDelta, a stateless env-side transform that stores ("next", obs) as a low-precision delta from the root obs for rollout memory savings on large continuous observations.
  • Wires up the previously-stubbed _post_step_mdp_hooks extension point in EnvBase.step_and_maybe_reset and threads it through Transform, Compose, and TransformedEnv. The hook receives both the post-step and post-step-mdp tensordicts so a transform can rehydrate the flowing td that the policy reads on the next iteration.
  • NextObservationDelta._step writes (next_obs - obs).to(delta_dtype) (default float16); _post_step_mdp_hooks reconstructs obs + delta in restore_dtype (default: match root). Stateless — no caching across steps.

Why this shape

The existing compact_obs collector flag + NextStateReconstructor RB transform attack the same problem by dropping ("next", obs) entirely and shifting at sample time. That is zero-storage but lossy at trajectory boundaries (which become NaN). The delta variant trades a small precision loss for boundary-preserving reconstruction and an env-side hook that does not need to know about collector internals.

The _post_step_mdp_hooks mechanism was already stubbed (commented out) in common.py, transforms/_base.py, and llm/chat.py. This PR enables it. The signature was changed from the original comment ((tensordict_,) -> tensordict_) to (tensordict, tensordict_) -> tensordict_ because rehydration needs read access to the post-step root obs. No caller existed before, so this is not a breaking change.

v1 limitations (documented on the class)

  • Lossy. Round-trip error scales with delta_dtype precision and observation magnitude.
  • Memory savings require non-pre-allocated stacked output. SyncDataCollector(use_buffers=False) or a lazy RB storage. Pre-allocated _final_rollout upcasts the write back to the original dtype and erases the saving.
  • Hook fires from step_and_maybe_reset only. env.rollout() is not wired in v1; direct rollout callers must rehydrate manually.
  • check_env_specs does not pass on the transformed env. observation_spec is shared between root and ("next", ...) in TorchRL; the transform does not fork it in v1 (a follow-up could). Tests use a reset+step smoke instead.
  • Batched-env composition. For SerialEnv/ParallelEnv, the transform belongs outside the batched env (i.e. TransformedEnv(ParallelEnv(...), NextObservationDelta())) — that path uses the outer step_and_maybe_reset and the hook fires. Putting the transform inside each worker is allowed and runs without error, but the outer batched env's step_and_maybe_reset does not currently propagate the hook so the stacked output upcasts.

Out of scope (potential follow-ups)

  • Forking observation_spec so pre-allocated _final_rollout benefits from the compression.
  • Wiring the hook in _rollout_stop_early and in batched_envs / async_envs / envpool step_and_maybe_reset.
  • A replay-buffer-side delta transform paired with this one.
  • Benchmark entry under benchmarks/.

Test plan

  • pytest test/transforms/test_observation_transforms.py::TestNextObservationDelta — 14 passed, 2 documented skips.
  • pytest --doctest-modules torchrl/envs/transforms/_observation.py -k NextObservationDelta — passes.
  • pytest test/envs/test_env_base.py — 47 passed, 4 skipped (no regressions from the hook wiring).
  • Manual smoke against GymEnv("Pendulum-v1") confirms ("next", "observation").dtype == torch.float16 post-step and torch.float32 on the flowing td, with bitwise-exact rehydration (max diff 0.0).
  • Compose(NextObservationDelta, RewardSum) works in both orderings.
  • Wider CI sweep (compose + env-transforms suites) — local disk filled before completing; relying on CI.

Adds a stateless env-side transform that stores `("next", obs)` as a
low-precision delta from the root `obs`, reducing the rollout-time
memory footprint of large continuous observations.

The transform compresses next observations in `_step` and rehydrates
the flowing tensordict's root observation in a new
`_post_step_mdp_hooks` extension point on `EnvBase`. The hook was
previously half-stubbed in `common.py` / `_base.py` / `llm/chat.py`;
it is now wired through `step_and_maybe_reset` and threaded into
`Transform`, `Compose`, and `TransformedEnv`.

Caveats documented on the class:

- The compression is lossy; round-trip error scales with delta dtype
  precision and observation magnitude.
- Memory savings only materialize against non-pre-allocated stacked
  output (e.g. `SyncDataCollector(use_buffers=False)` or a lazy RB
  storage). Pre-allocated buffers upcast the write.
- The hook fires from `step_and_maybe_reset`; direct `env.rollout()`
  callers must rehydrate manually.
- `check_env_specs` rejects the transformed env in v1 because the
  observation spec is shared between root and `("next", ...)` and we
  do not fork it.

Includes a `TestNextObservationDelta` test class with 16 cases
(14 passing, 2 documented skips) covering single-env, serial/parallel
batched envs (inner and outer wrapping), auto-inference skipping
non-floating dtypes, multi-key, reset semantics, Compose ordering,
and an end-to-end `SyncDataCollector(use_buffers=False)` check that
the stacked batch carries `float16` `("next", obs)`.
@pytorch-bot
Copy link
Copy Markdown

pytorch-bot Bot commented May 18, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/3777

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

❌ 4 New Failures

As of commit 0ef4983 with merge base 996387f (image):

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label May 18, 2026
@github-actions github-actions Bot added Documentation Improvements or additions to documentation Transforms Feature New feature labels May 18, 2026
vmoens added 3 commits May 18, 2026 18:05
- Wire `_post_step_mdp_hooks` in `EnvBase._rollout_stop_early` so
  `env.rollout(..., break_when_any_done=True)` rehydrates the flowing
  td just like `step_and_maybe_reset` already did. The non-stop path
  already routed through `step_and_maybe_reset` and is unchanged.

- Add `Transform.transform_fake_tensordict(td)` hook (no-op default),
  iterated by `Compose`, called by a new `TransformedEnv.fake_tensordict`
  override. `NextObservationDelta` overrides it to cast the
  `("next", key)` leaves to `delta_dtype` in the spec-derived fake td.
  Pre-allocated `_final_rollout` in `SyncDataCollector(use_buffers=True)`
  now reserves storage at the compressed dtype rather than upcasting
  writes; the collector test covers both `use_buffers={True, False}`.

- Add `Transform._check_batched_worker_compat()` (no-op default).
  `NextObservationDelta` raises with a clear message pointing at the
  correct usage pattern. `BatchedEnvBase._get_metadata` builds a
  transient probe env and runs the validator via a new `env_validator`
  kwarg on `get_env_metadata`, so the inner-batched configuration
  fails loudly at construction time instead of silently upcasting at
  runtime.

The remaining v1 caveat in the docstring is that `check_env_specs`
still does not pass: it calls `observation_spec.contains(("next", obs))`
and TorchRL shares `observation_spec` between root and `("next", ...)`
leaves, so a compressed dtype is rejected. Working around this
properly requires forking the spec system, which is out of scope for
this PR. Tests use a reset+step smoke instead.
Subtracting in delta_dtype (float16 by default) risks catastrophic
cancellation when next_obs and obs are close. Doing the subtraction
in the operands' source dtype and casting the result once preserves
significand bits and is strictly more accurate on round-trip.

The stored root obs is unchanged, so there is no asymmetry to
preserve between the on-the-fly delta and the value reconstructed
from storage.
@vmoens
Copy link
Copy Markdown
Collaborator Author

vmoens commented May 19, 2026

@elin-bdai @theap06 maybe you could help review this one?
It reduces the size of ("next", "observation") fields by 50% when the transform is applied by using low-prec representation of the delta between o_t and o'_t
That looks like a good compromise between efficiency and correctness to carry the last obs of a trajectory/rollout.

Copy link
Copy Markdown
Contributor

@elin-bdai elin-bdai left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for doing this! I'm going to test this with our longer training jobs this week to make sure loss of precision is not a problem. Just a comment in terms of reducing confusion.

>>> td_root = env.reset()
>>> _ = td_root.set("action", env.action_spec.rand())
>>> td, td_ = env.step_and_maybe_reset(td_root)
>>> td["next", "observation"].dtype
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I'm understanding correctly, I think it's confusing here when using NextObservationDelta() that what's inside td["next", "observation"] is the delta, but the tensordict is indistinguishable from when you don't use NextObservationDelta, so you're not sure if it's the delta or not stored in there. It could lead to confusion when inspecting the outputs at different points.

# operand to ``delta_dtype`` first and subtracting in low precision
# (which would risk catastrophic cancellation for nearby values).
delta = (next_obs - obs).to(self.delta_dtype)
next_tensordict.set(key, delta)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to change the key here to {key}_delta?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. Documentation Improvements or additions to documentation Feature New feature Transforms

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants